import numpy as np
import time 
import logging
import os
import random
import torch
import torch.utils.data

import pandas as pd 
import csv

class Dataset(torch.utils.data.Dataset):
    """Base dataset storing data-location settings plus shared sampling helpers.

    Subclasses are expected to implement ``__len__`` and ``__getitem__``;
    this base class only stores its constructor arguments and provides
    point-cloud / labeled-point sampling utilities.
    """

    def __init__(
        self,
        data_source,
        split_file,
        subsample,
        gt_filename,
    ):
        # Stored as-is; their interpretation is left to subclasses.
        self.data_source = data_source
        self.subsample = subsample
        self.split_file = split_file
        self.gt_filename = gt_filename

    def __len__(self):
        # Raise (not return) the exception so callers get a clear "abstract"
        # signal instead of a TypeError from len() receiving a class object.
        raise NotImplementedError("subclasses of Dataset must implement __len__")

    def __getitem__(self, idx):
        raise NotImplementedError("subclasses of Dataset must implement __getitem__")

    def sample_pointcloud(self, csvfile, pc_size):
        """Sample ``pc_size`` surface points from a headerless CSV file.

        Rows whose last column equals 0 are treated as surface points and
        only their first three columns are kept. Points are drawn without
        replacement when enough exist, otherwise with replacement, so the
        result always has ``pc_size`` rows.

        Args:
            csvfile: path to a comma-separated file with no header row.
            pc_size: number of points to return.

        Returns:
            float32 tensor with ``pc_size`` rows of the first three columns.

        Raises:
            ValueError: (from numpy) if there are no surface rows and pc_size > 0.
        """
        values = pd.read_csv(csvfile, sep=',', header=None).values
        surface = values[values[:, -1] == 0][:, :3]
        n_surface = surface.shape[0]
        pc_idx = np.random.choice(n_surface, pc_size, replace=n_surface < pc_size)
        return torch.from_numpy(surface[pc_idx]).float()

    @staticmethod
    def _draw_indices(population, count):
        """Return exactly ``count`` random indices into ``range(population)``.

        Draws without replacement when the population is large enough and
        with replacement otherwise.
        """
        if population >= count:
            return torch.randperm(population)[:count]
        return torch.randint(0, population, (count,))

    def labeled_sampling(self, f, subsample, pc_size=1024, load_from_path=True, label='sdf'):
        """Draw a balanced set of labeled query points plus a surface point cloud.

        The last column of ``f`` is used as the label: rows with a negative
        value are "negative", positive are "positive", and exactly zero are
        surface points. Half of ``subsample`` is drawn from each sign; if one
        sign has no rows, that half is drawn from the other sign instead.
        Sampling is without replacement when enough rows exist, otherwise
        with replacement, so both halves always have ``subsample // 2`` rows.

        Args:
            f: a path to an ``.npz`` archive when ``load_from_path`` is True,
                otherwise a 2-D tensor (presumably columns xyz + value —
                TODO confirm against callers).
            subsample: total number of labeled points to return (rounded
                down to an even count).
            pc_size: number of surface points to return; padded with
                replacement when fewer exist.
            load_from_path: whether ``f`` is a path to load.
            label: which archive array to load: 'sdf' -> 'sdf_data',
                'grid' -> 'grid_data'.

        Returns:
            (pc, xyz, values): surface xyz points, query xyz points, and the
            query values taken from column 3; each squeezed and cast to float32.

        Raises:
            ValueError: for an unknown ``label`` when loading, or when points
                are requested but ``f`` has no rows with a non-zero label.
        """
        if load_from_path:
            archive_keys = {'sdf': 'sdf_data', 'grid': 'grid_data'}
            if label not in archive_keys:
                raise ValueError("unknown label '{}'; expected 'sdf' or 'grid'".format(label))
            # The context manager closes the file handle NpzFile keeps open.
            with np.load(f) as archive:
                f = torch.from_numpy(archive[archive_keys[label]])

        half = int(subsample / 2)
        neg_tensor = f[f[:, -1] < 0]
        pos_tensor = f[f[:, -1] > 0]

        # If one sign has no rows, draw that half from the other sign instead.
        pos_source = pos_tensor if pos_tensor.shape[0] > 0 else neg_tensor
        neg_source = neg_tensor if neg_tensor.shape[0] > 0 else pos_tensor
        if half > 0 and pos_source.shape[0] == 0:
            raise ValueError("no points with a non-zero label to sample from")

        pos_sample = pos_source[self._draw_indices(pos_source.shape[0], half)]
        neg_sample = neg_source[self._draw_indices(neg_source.shape[0], half)]

        # Surface points (label exactly 0). Pad with replacement when short so
        # the output size is fixed, mirroring sample_pointcloud.
        pc = f[f[:, -1] == 0][:, :3]
        if pc.shape[0] > 0:
            pc = pc[self._draw_indices(pc.shape[0], pc_size)]

        samples = torch.cat([pos_sample, neg_sample], 0)

        # NOTE(review): values come from column 3 while filtering uses the
        # last column; these agree only for 4-column data — confirm.
        return pc.float().squeeze(), samples[:, :3].float().squeeze(), samples[:, 3].float().squeeze()

    def get_instance_filenames(self, data_source, split, gt_filename="sdf_data.csv", filter_modulation_path=None):
        """Resolve split entries to existing ``.npz`` files under ``data_source``.

        Each entry in ``split`` has ".json" replaced by ".npz" and is joined
        onto ``data_source``; entries whose file does not exist are skipped
        with a warning. When ``filter_modulation_path`` is given, entries are
        also skipped unless
        ``<filter_modulation_path>/<class_name>/<stem>/latent.txt`` exists,
        where ``class_name`` is the third-from-last path component of
        ``data_source`` and ``stem`` is the entry without its extension.

        Args:
            data_source: directory the split entries are relative to.
            split: iterable of entry names (typically ending in ".json").
            gt_filename: unused; kept for interface compatibility.
            filter_modulation_path: optional root directory of per-instance
                latent files used to filter entries.

        Returns:
            List of existing ``.npz`` paths, in split order.
        """
        do_filter = filter_modulation_path is not None
        if do_filter:
            # Only needed when filtering, so shallow data_source paths work
            # otherwise; normpath drops a trailing separator that would
            # otherwise shift the [-3] index.
            class_name = os.path.normpath(data_source).split(os.sep)[-3]

        npzfiles = []
        for entry in split:
            npz_name = entry.replace(".json", ".npz")
            instance_filename = os.path.join(data_source, npz_name)

            if do_filter:
                stem = os.path.splitext(npz_name)[0]
                mod_file = os.path.join(filter_modulation_path, class_name, stem, "latent.txt")
                if not os.path.isfile(mod_file):
                    continue

            if not os.path.isfile(instance_filename):
                logging.warning("Requested non-existent file '{}'".format(instance_filename))
                continue

            npzfiles.append(instance_filename)
        return npzfiles
